{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#----------------------------------------------------------------------\n",
    "# Purpose: Train a GBM model. Fetch and traverse the backed tree\n",
    "#----------------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "versionFromGradle='3.20.0',projectVersion='3.20.0.99999',branch='rel-wright',lastCommitHash='17cd2095ef4547f12c3efc122ceba7132a8a8f56',gitDescribe='jenkins-3.20.0.7-6-g17cd2095ef-dirty',compiledOn='2018-09-11 15:07:39',compiledBy='pavel'\n"
     ]
    }
   ],
   "source": [
    "import h2o\n",
    "from h2o.estimators.gbm import H2OGradientBoostingEstimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking whether there is an H2O instance running at http://localhost:54321. connected.\n",
      "versionFromGradle='3.20.0',projectVersion='3.20.0.99999',branch='rel-wright',lastCommitHash='17cd2095ef4547f12c3efc122ceba7132a8a8f56',gitDescribe='jenkins-3.20.0.7-6-g17cd2095ef-dirty',compiledOn='2018-09-11 15:07:39',compiledBy='pavel'\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td>H2O cluster uptime:</td>\n",
       "<td>29 secs</td></tr>\n",
       "<tr><td>H2O cluster timezone:</td>\n",
       "<td>Europe/Prague</td></tr>\n",
       "<tr><td>H2O data parsing timezone:</td>\n",
       "<td>UTC</td></tr>\n",
       "<tr><td>H2O cluster version:</td>\n",
       "<td>3.20.0.99999</td></tr>\n",
       "<tr><td>H2O cluster version age:</td>\n",
       "<td>21 hours and 17 minutes </td></tr>\n",
       "<tr><td>H2O cluster name:</td>\n",
       "<td>pavel</td></tr>\n",
       "<tr><td>H2O cluster total nodes:</td>\n",
       "<td>1</td></tr>\n",
       "<tr><td>H2O cluster free memory:</td>\n",
       "<td>3.452 Gb</td></tr>\n",
       "<tr><td>H2O cluster total cores:</td>\n",
       "<td>12</td></tr>\n",
       "<tr><td>H2O cluster allowed cores:</td>\n",
       "<td>12</td></tr>\n",
       "<tr><td>H2O cluster status:</td>\n",
       "<td>accepting new members, healthy</td></tr>\n",
       "<tr><td>H2O connection url:</td>\n",
       "<td>http://localhost:54321</td></tr>\n",
       "<tr><td>H2O connection proxy:</td>\n",
       "<td>None</td></tr>\n",
       "<tr><td>H2O internal security:</td>\n",
       "<td>False</td></tr>\n",
       "<tr><td>H2O API Extensions:</td>\n",
       "<td>XGBoost, Algos, Core V3, Core V4</td></tr>\n",
       "<tr><td>Python version:</td>\n",
       "<td>3.6.5 final</td></tr></table></div>"
      ],
      "text/plain": [
       "--------------------------  --------------------------------\n",
       "H2O cluster uptime:         29 secs\n",
       "H2O cluster timezone:       Europe/Prague\n",
       "H2O data parsing timezone:  UTC\n",
       "H2O cluster version:        3.20.0.99999\n",
       "H2O cluster version age:    21 hours and 17 minutes\n",
       "H2O cluster name:           pavel\n",
       "H2O cluster total nodes:    1\n",
       "H2O cluster free memory:    3.452 Gb\n",
       "H2O cluster total cores:    12\n",
       "H2O cluster allowed cores:  12\n",
       "H2O cluster status:         accepting new members, healthy\n",
       "H2O connection url:         http://localhost:54321\n",
       "H2O connection proxy:\n",
       "H2O internal security:      False\n",
       "H2O API Extensions:         XGBoost, Algos, Core V3, Core V4\n",
       "Python version:             3.6.5 final\n",
       "--------------------------  --------------------------------"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "h2o.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parse progress: |█████████████████████████████████████████████████████████| 100%\n"
     ]
    }
   ],
   "source": [
    "from h2o.utils.shared_utils import _locate # private function. used to find files within h2o git project directory.\n",
    "\n",
    "air = h2o.import_file(_locate(\"smalldata/airlines/allyears2k_headers.zip\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Create new GBM model with limited number of trees and limited depth.\n",
    "gbm_air_model = H2OGradientBoostingEstimator(ntrees = 3, max_depth = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gbm Model Build progress: |███████████████████████████████████████████████| 100%\n"
     ]
    }
   ],
   "source": [
    "gbm_air_model.train(x = [\"Origin\", \"Distance\", \"UniqueCarrier\"], y = \"IsDepDelayed\", training_frame = air)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Fetch the very first tree (out of 3 total trees)\n",
    "from h2o.tree import H2OTree, H2ONode\n",
    "tree = H2OTree(gbm_air_model, 0, \"NO\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Print number of nodes in a tree\n",
    "len(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node ID 0 \n",
      "Left child node ID = 1\n",
      "Right child node ID = 2\n",
      "\n",
      "Splits on column Origin\n",
      "  - Categorical levels going to the left node: ['ABE', 'ABQ', 'ACY', 'AUS', 'AVP', 'BHM', 'BIL', 'BNA', 'BOI', 'BOS', 'BUF', 'BUR', 'CAE', 'CHS', 'CLE', 'CLT', 'COS', 'CRP', 'CRW', 'DCA', 'DEN', 'DSM', 'DTW', 'EGE', 'EWR', 'EYW', 'GEG', 'GNV', 'GSO', 'HNL', 'HRL', 'IAD', 'IAH', 'JAX', 'JFK', 'KOA', 'LAN', 'LBB', 'LIH', 'MAF', 'MDT', 'MEM', 'MHT', 'MKE', 'MLB', 'MRY', 'MSP', 'MSY', 'MYR', 'OAK', 'OGG', 'OKC', 'PHF', 'PHX', 'PWM', 'RDU', 'SAN', 'SAV', 'SBN', 'SDF', 'SJC', 'SLC', 'SMF', 'STL', 'STT', 'TLH', 'TRI', 'TUL', 'TUS', 'TYS']\n",
      "  - Categorical levels going to the right node: ['ALB', 'AMA', 'ANC', 'ATL', 'BDL', 'BGM', 'BTV', 'BWI', 'CHO', 'CMH', 'CVG', 'DAL', 'DAY', 'DFW', 'ELP', 'ERI', 'FLL', 'GRR', 'HOU', 'HPN', 'ICT', 'IND', 'ISP', 'JAN', 'LAS', 'LAX', 'LEX', 'LGA', 'LIT', 'LYH', 'MCI', 'MCO', 'MDW', 'MFR', 'MIA', 'OMA', 'ONT', 'ORD', 'ORF', 'PBI', 'PDX', 'PHL', 'PIT', 'PSP', 'PVD', 'RIC', 'RNO', 'ROA', 'ROC', 'RSW', 'SAT', 'SCK', 'SEA', 'SFO', 'SJU', 'SNA', 'SRQ', 'STX', 'SWF', 'SYR', 'TPA', 'UCA']\n",
      "\n",
      "NA values go to the RIGHT\n"
     ]
    }
   ],
   "source": [
    "# Show description of root node\n",
    "tree.root_node.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Leaf node ID 18. Predicted value at leaf node is 0.044361357 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Show description of a terminal node\n",
    "tree.root_node.left_child.right_child.right_child.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Origin\n",
      "['ABE', 'ABQ', 'ACY', 'AUS', 'AVP', 'BHM', 'BIL', 'BNA', 'BOI', 'BOS', 'BUF', 'BUR', 'CAE', 'CHS', 'CLE', 'CLT', 'COS', 'CRP', 'CRW', 'DCA', 'DEN', 'DSM', 'DTW', 'EGE', 'EWR', 'EYW', 'GEG', 'GNV', 'GSO', 'HNL', 'HRL', 'IAD', 'IAH', 'JAX', 'JFK', 'KOA', 'LAN', 'LBB', 'LIH', 'MAF', 'MDT', 'MEM', 'MHT', 'MKE', 'MLB', 'MRY', 'MSP', 'MSY', 'MYR', 'OAK', 'OGG', 'OKC', 'PHF', 'PHX', 'PWM', 'RDU', 'SAN', 'SAV', 'SBN', 'SDF', 'SJC', 'SLC', 'SMF', 'STL', 'STT', 'TLH', 'TRI', 'TUL', 'TUS', 'TYS']\n",
      "['ALB', 'AMA', 'ANC', 'ATL', 'BDL', 'BGM', 'BTV', 'BWI', 'CHO', 'CMH', 'CVG', 'DAL', 'DAY', 'DFW', 'ELP', 'ERI', 'FLL', 'GRR', 'HOU', 'HPN', 'ICT', 'IND', 'ISP', 'JAN', 'LAS', 'LAX', 'LEX', 'LGA', 'LIT', 'LYH', 'MCI', 'MCO', 'MDW', 'MFR', 'MIA', 'OMA', 'ONT', 'ORD', 'ORF', 'PBI', 'PDX', 'PHL', 'PIT', 'PSP', 'PVD', 'RIC', 'RNO', 'ROA', 'ROC', 'RSW', 'SAT', 'SCK', 'SEA', 'SFO', 'SJU', 'SNA', 'SRQ', 'STX', 'SWF', 'SYR', 'TPA', 'UCA']\n",
      "nan\n",
      "RIGHT\n"
     ]
    }
   ],
   "source": [
    "# Show raw attributes of a tree node\n",
    "print(tree.root_node.split_feature)\n",
    "print(tree.root_node.left_levels)\n",
    "print(tree.root_node.right_levels)\n",
    "print(tree.root_node.threshold)\n",
    "print(tree.root_node.na_direction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 3, 5, 7, 9, 11, 13, -1, -1, -1, -1, -1, -1, -1, -1]\n",
      "[2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1]\n",
      "[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]\n",
      "['Origin', 'Origin', 'UniqueCarrier', 'UniqueCarrier', 'UniqueCarrier', 'UniqueCarrier', 'Origin', None, None, None, None, None, None, None, None]\n"
     ]
    }
   ],
   "source": [
    "# Print some of the raw attributes of a tree available\n",
    "print(tree.left_children)\n",
    "print(tree.right_children)\n",
    "print(tree.thresholds)\n",
    "print(tree.features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}